from utils import *
from model import *
import torch


def train_gaussian(mus, d, repeat, data, Adj, device, epochs, num_classes, dataset):
    """Compare ``LinearAgg`` and ``OptimalMP`` on semi-synthetic features drawn by ``high_dim_Gaussian``.

    One one-vs-rest task is run per class ``lidx``. Nodes of class ``lidx`` get
    sign label -1 (BCE target 0); every other node gets +1 (BCE target 1).
    For every ``mu`` in ``mus`` the experiment is repeated ``repeat`` times:
    - a training and a test feature matrix are drawn independently with the same node labels;
    - both models are trained full-batch for ``epochs`` epochs;
    - train/test accuracies are recorded, together with the accuracy of the
      plain ``sign(X @ mu)`` classifier ("original").

    Args:
        mus: iterable of numpy arrays; each is passed to ``high_dim_Gaussian``
            and converted with ``torch.from_numpy`` for the angle diagnostics.
        d: feature dimension forwarded to ``high_dim_Gaussian`` and the models.
        repeat: number of independent draws/trainings per ``mu``.
        data: graph object providing ``y`` (node class labels) and ``num_nodes``.
        Adj: adjacency argument forwarded unchanged to the models.
        device: torch device for features, labels and models.
        epochs: number of training epochs; must be >= 1.
        num_classes: number of one-vs-rest tasks (classes ``0..num_classes-1``).
        dataset: name embedded in the output file name.

    Side effects:
        Prints progress. For each class, saves
        ``./results/semi_synthetic_Gau_{dataset}_l{lidx}.npy``; the directory is
        created if missing. Each row of that array corresponds to one ``mu``::

            [||mu||,
             org_train_avg, mp_train_avg, linear_train_avg,
             org_test_avg,  mp_test_avg,  linear_test_avg,
             org_train_std, org_test_std, mp_train_std, mp_test_std,
             linear_train_std, linear_test_std]

    Raises:
        ValueError: if ``epochs`` < 1 (no training accuracy would be defined).
    """
    import os

    if epochs < 1:
        # The training accuracies recorded below are only bound inside the epoch loop.
        raise ValueError(f'epochs must be >= 1, got {epochs}')
    # np.save does not create missing directories.
    os.makedirs('./results', exist_ok=True)

    for lidx in range(num_classes):
        # One-vs-rest labels: class lidx -> -1 / 0, all other classes -> +1 / 1.
        Node_labels = np.where((data.y == lidx), -1, 1)
        label = torch.from_numpy(Node_labels).unsqueeze(dim=1).to(device).float()
        label_loss = torch.from_numpy(np.where((data.y == lidx), 0, 1)).unsqueeze(dim=1).to(device).float()

        results = []
        for mu in mus:
            # Columns: org, org_eval, mp, mp_eval, linear, linear_eval.
            record = np.zeros([repeat, 6])
            mu_torch = torch.from_numpy(mu).float()
            for i in range(repeat):
                print('=' * 20)
                print(f'mu = {mu}')
                # Independent train / test feature draws sharing the same node labels.
                X_feature = high_dim_Gaussian(Node_labels, d, mu)
                X_feature_torch = torch.from_numpy(X_feature).to(device)
                X_feature_test = high_dim_Gaussian(Node_labels, d, mu)
                X_feature_torch_test = torch.from_numpy(X_feature_test).to(device)
                # Baseline: classify each node by sign(x . mu), no graph information.
                acc_org = (np.sign(X_feature @ mu) == Node_labels).sum() / data.num_nodes
                print(f'acc_org={acc_org}')
                acc_org_eval = (np.sign(X_feature_test @ mu) == Node_labels).sum() / data.num_nodes

                model_linear = LinearAgg(d).to(device)
                optimizer1 = torch.optim.Adam(model_linear.parameters(), lr=0.01, weight_decay=5e-4)
                criterion1 = nn.BCELoss()

                model_mp = OptimalMP(d).to(device)
                optimizer2 = torch.optim.Adam(model_mp.parameters(), lr=0.01, weight_decay=5e-4)
                criterion2 = nn.BCELoss()

                model_linear.train()
                model_mp.train()
                for epoch in range(epochs):
                    optimizer1.zero_grad()
                    optimizer2.zero_grad()
                    out_linear, out_linear_weight = model_linear(Adj, X_feature_torch)
                    out_mp, out_mp_weight, out_thres = model_mp(Adj, X_feature_torch)
                    # Outputs are fed to BCELoss against 0/1 targets; an output above 0.5 maps to +1.
                    predicted_label_linear = torch.sign(out_linear - 0.5)
                    predicted_label_mp = torch.sign(out_mp - 0.5)
                    loss_linear = criterion1(out_linear, label_loss)
                    loss_mp = criterion2(out_mp, label_loss)
                    acc_linear = (predicted_label_linear == label).sum() / data.num_nodes
                    acc_mp = (predicted_label_mp == label).sum() / data.num_nodes
                    loss_linear.backward()
                    loss_mp.backward()
                    optimizer1.step()
                    optimizer2.step()
                    if epoch % 50 == 0:
                        print(f'Epoch:{epoch}')
                        print(
                            f'Linear weight angle:{angle(out_linear_weight.detach().cpu().squeeze(), mu_torch).item()}')
                        print(f'MP weight angle:{angle(out_mp_weight.detach().cpu().squeeze(), mu_torch).item()}')
                        print(f'MP Threshold:{out_thres}')
                        print(f'Linear training loss:{loss_linear}')
                        print(f'MP training loss:{loss_mp}')
                        print(f'acc_linear:{acc_linear}')
                        print(f'acc_mp:{acc_mp}')

                model_linear.eval()
                model_mp.eval()

                # Evaluation only: skip building an autograd graph.
                with torch.no_grad():
                    out_linear_eval, _ = model_linear(Adj, X_feature_torch_test)
                    out_mp_eval, _, _ = model_mp(Adj, X_feature_torch_test)

                predicted_label_linear_eval = torch.sign(out_linear_eval - 0.5)
                predicted_label_mp_eval = torch.sign(out_mp_eval - 0.5)

                acc_linear_eval = (predicted_label_linear_eval == label).sum() / data.num_nodes
                acc_mp_eval = (predicted_label_mp_eval == label).sum() / data.num_nodes
                print('*' * 20)
                print(f'acc_linear:{acc_linear_eval}')
                print(f'acc_mp:{acc_mp_eval}')
                # Training accuracies are those of the final epoch (computed before its optimizer step).
                record[i] = [acc_org, acc_org_eval, acc_mp.item(), acc_mp_eval.item(), acc_linear.item(),
                             acc_linear_eval.item()]
            avg = record.mean(axis=0)
            std = record.std(axis=0)
            # Row layout: norm, train avgs (even columns), test avgs (odd columns), then all stds in record order.
            results.append([LA.norm(mu_torch).item(), *avg[0::2], *avg[1::2], *std])
        np.save(f'./results/semi_synthetic_Gau_{dataset}_l{lidx}.npy', results)


def train_laplacian(mus, d, repeat, data, Adj, device, epochs, num_classes, dataset):
    """Compare four models on semi-synthetic features drawn by ``high_dim_Laplace``.

    The models are ``LinearAgg``, ``OptimalMP_Laplacian``,
    ``OptimalMP_Laplacian_Phi`` and ``OptimalMP_Laplacian_Psi``.

    One one-vs-rest task is run per class ``lidx``. Nodes of class ``lidx`` get
    sign label -1 (BCE target 0); every other node gets +1 (BCE target 1).
    For every ``mu`` in ``mus`` the experiment is repeated ``repeat`` times:
    - a training and a test feature matrix are drawn independently with the same node labels;
    - all four models are trained full-batch for ``epochs`` epochs;
    - train/test accuracies are recorded, together with the accuracy of the
      plain ``sign(X @ mu)`` classifier ("original").

    Args:
        mus: iterable of numpy arrays; each is passed to ``high_dim_Laplace``
            and converted with ``torch.from_numpy`` for the angle diagnostics.
        d: feature dimension forwarded to ``high_dim_Laplace`` and the models.
        repeat: number of independent draws/trainings per ``mu``.
        data: graph object providing ``y`` (node class labels) and ``num_nodes``.
        Adj: adjacency argument forwarded unchanged to the models.
        device: torch device for features, labels and models.
        epochs: number of training epochs; must be >= 1.
        num_classes: number of one-vs-rest tasks (classes ``0..num_classes-1``).
        dataset: name embedded in the output file name.

    Side effects:
        Prints progress. For each class, saves
        ``./results/semi_synthetic_LP_{dataset}_l{lidx}.npy``; the directory is
        created if missing. Each row of that array corresponds to one ``mu``::

            [||mu||,
             org_train_avg, mp_train_avg, linear_train_avg, phi_train_avg, psi_train_avg,
             org_test_avg,  mp_test_avg,  linear_test_avg,  phi_test_avg,  psi_test_avg,
             org_train_std, org_test_std, mp_train_std, mp_test_std,
             linear_train_std, linear_test_std, phi_train_std, phi_test_std,
             psi_train_std, psi_test_std]

    Raises:
        ValueError: if ``epochs`` < 1 (no training accuracy would be defined).
    """
    import os

    if epochs < 1:
        # The training accuracies recorded below are only bound inside the epoch loop.
        raise ValueError(f'epochs must be >= 1, got {epochs}')
    # np.save does not create missing directories.
    os.makedirs('./results', exist_ok=True)

    for lidx in range(num_classes):
        # One-vs-rest labels: class lidx -> -1 / 0, all other classes -> +1 / 1.
        Node_labels = np.where((data.y == lidx), -1, 1)
        label = torch.from_numpy(Node_labels).unsqueeze(dim=1).to(device).float()
        label_loss = torch.from_numpy(np.where((data.y == lidx), 0, 1)).unsqueeze(dim=1).to(device).float()

        results = []
        for mu in mus:
            # Columns: org, org_eval, mp, mp_eval, linear, linear_eval, phi, phi_eval, psi, psi_eval.
            record = np.zeros([repeat, 10])
            mu_torch = torch.from_numpy(mu).float()
            for i in range(repeat):
                print('=' * 20)
                print(f'mu = {mu}')
                # Independent train / test feature draws sharing the same node labels.
                X_feature = high_dim_Laplace(Node_labels, d, mu)
                X_feature_torch = torch.from_numpy(X_feature).to(device)
                X_feature_test = high_dim_Laplace(Node_labels, d, mu)
                X_feature_torch_test = torch.from_numpy(X_feature_test).to(device)
                # Baseline: classify each node by sign(x . mu), no graph information.
                acc_org = (np.sign(X_feature @ mu) == Node_labels).sum() / data.num_nodes
                print(f'acc_org={acc_org}')
                acc_org_eval = (np.sign(X_feature_test @ mu) == Node_labels).sum() / data.num_nodes

                model_linear = LinearAgg(d).to(device)
                optimizer1 = torch.optim.Adam(model_linear.parameters(), lr=0.01, weight_decay=5e-4)
                criterion1 = nn.BCELoss()

                model_mp = OptimalMP_Laplacian(d).to(device)
                optimizer2 = torch.optim.Adam(model_mp.parameters(), lr=0.01, weight_decay=5e-4)
                criterion2 = nn.BCELoss()

                model_phi = OptimalMP_Laplacian_Phi(d).to(device)
                optimizer3 = torch.optim.Adam(model_phi.parameters(), lr=0.01, weight_decay=5e-4)
                criterion3 = nn.BCELoss()

                model_psi = OptimalMP_Laplacian_Psi(d).to(device)
                optimizer4 = torch.optim.Adam(model_psi.parameters(), lr=0.01, weight_decay=5e-4)
                criterion4 = nn.BCELoss()

                # All four models are switched to eval() below, so all four are set to train() here.
                model_linear.train()
                model_mp.train()
                model_phi.train()
                model_psi.train()
                for epoch in range(epochs):
                    optimizer1.zero_grad()
                    optimizer2.zero_grad()
                    optimizer3.zero_grad()
                    optimizer4.zero_grad()
                    out_linear, out_linear_weight = model_linear(Adj, X_feature_torch)
                    out_mp, out_mp_weight, out_thres_1, out_thres_2 = model_mp(Adj, X_feature_torch)
                    out_phi, out_phi_weight, out_thres_phi = model_phi(Adj, X_feature_torch)
                    out_psi, out_psi_weight, out_thres_psi = model_psi(Adj, X_feature_torch)
                    # Outputs are fed to BCELoss against 0/1 targets; an output above 0.5 maps to +1.
                    predicted_label_linear = torch.sign(out_linear - 0.5)
                    predicted_label_mp = torch.sign(out_mp - 0.5)
                    predicted_label_phi = torch.sign(out_phi - 0.5)
                    predicted_label_psi = torch.sign(out_psi - 0.5)
                    loss_linear = criterion1(out_linear, label_loss)
                    loss_mp = criterion2(out_mp, label_loss)
                    loss_phi = criterion3(out_phi, label_loss)
                    loss_psi = criterion4(out_psi, label_loss)
                    acc_linear = (predicted_label_linear == label).sum() / data.num_nodes
                    acc_mp = (predicted_label_mp == label).sum() / data.num_nodes
                    acc_phi = (predicted_label_phi == label).sum() / data.num_nodes
                    acc_psi = (predicted_label_psi == label).sum() / data.num_nodes
                    loss_linear.backward()
                    loss_mp.backward()
                    loss_phi.backward()
                    loss_psi.backward()
                    optimizer1.step()
                    optimizer2.step()
                    optimizer3.step()
                    optimizer4.step()
                    if epoch % 50 == 0:
                        print(f'Epoch:{epoch}')
                        print(
                            f'Linear weight angle:{angle(out_linear_weight.detach().cpu().squeeze(), mu_torch).item()}')
                        print(f'MP weight angle:{angle(out_mp_weight.detach().cpu().squeeze(), mu_torch).item()}')
                        print(f'Phi weight angle:{angle(out_phi_weight.detach().cpu().squeeze(), mu_torch).item()}')
                        print(f'Psi weight angle:{angle(out_psi_weight.detach().cpu().squeeze(), mu_torch).item()}')
                        print(f'MP Threshold 1:{out_thres_1}')
                        print(f'MP Threshold 2:{out_thres_2}')
                        print(f'Threshold Phi:{out_thres_phi}')
                        print(f'Threshold Psi:{out_thres_psi}')
                        print(f'Linear training loss:{loss_linear}')
                        print(f'MP training loss:{loss_mp}')
                        print(f'Phi training loss:{loss_phi}')
                        print(f'Psi training loss:{loss_psi}')
                        print(f'acc_linear:{acc_linear}')
                        print(f'acc_mp:{acc_mp}')
                        print(f'acc_phi:{acc_phi}')
                        print(f'acc_psi:{acc_psi}')

                model_linear.eval()
                model_mp.eval()
                model_phi.eval()
                model_psi.eval()

                # Evaluation only: skip building an autograd graph.
                with torch.no_grad():
                    out_linear_eval, _ = model_linear(Adj, X_feature_torch_test)
                    out_mp_eval, _, _, _ = model_mp(Adj, X_feature_torch_test)
                    out_phi_eval, _, _ = model_phi(Adj, X_feature_torch_test)
                    out_psi_eval, _, _ = model_psi(Adj, X_feature_torch_test)

                predicted_label_linear_eval = torch.sign(out_linear_eval - 0.5)
                predicted_label_mp_eval = torch.sign(out_mp_eval - 0.5)
                predicted_label_phi_eval = torch.sign(out_phi_eval - 0.5)
                predicted_label_psi_eval = torch.sign(out_psi_eval - 0.5)

                acc_linear_eval = (predicted_label_linear_eval == label).sum() / data.num_nodes
                acc_mp_eval = (predicted_label_mp_eval == label).sum() / data.num_nodes
                acc_phi_eval = (predicted_label_phi_eval == label).sum() / data.num_nodes
                acc_psi_eval = (predicted_label_psi_eval == label).sum() / data.num_nodes
                print('*' * 20)
                print(f'acc_linear:{acc_linear_eval}')
                print(f'acc_mp:{acc_mp_eval}')
                print(f'acc_phi:{acc_phi_eval}')
                print(f'acc_psi:{acc_psi_eval}')

                # Training accuracies are those of the final epoch (computed before its optimizer step).
                record[i] = [acc_org, acc_org_eval, acc_mp.item(), acc_mp_eval.item(), acc_linear.item(),
                             acc_linear_eval.item(), acc_phi.item(), acc_phi_eval.item(), acc_psi.item(),
                             acc_psi_eval.item()]
            avg = record.mean(axis=0)
            std = record.std(axis=0)
            # Row layout: norm, train avgs (even columns), test avgs (odd columns), then all stds in record order.
            results.append([LA.norm(mu_torch).item(), *avg[0::2], *avg[1::2], *std])
        np.save(f'./results/semi_synthetic_LP_{dataset}_l{lidx}.npy', results)
